package com.kool.kmqtt.server.repository.redis;

import com.alibaba.fastjson.JSON;
import com.kool.kmqtt.server.exception.AppException;
import com.kool.kmqtt.server.exception.ErrorCode;
import com.kool.kmqtt.server.log.ClientNoAckSendPacketCnt;
import com.kool.kmqtt.server.packet.Packet;
import com.kool.kmqtt.server.repository.Repository;
import com.kool.kmqtt.server.repository.subscription.SubscribeInfo;
import com.kool.kmqtt.server.repository.subscription.Subscriber;
import com.kool.kmqtt.server.repository.subscription.Subscription;
import com.kool.kmqtt.server.session.WillStatus;
import com.kool.kmqtt.util.StringUtil;
import lombok.extern.slf4j.Slf4j;

import java.util.*;
import java.util.stream.Collectors;

import static com.kool.kmqtt.util.TopicUtil.split;

/**
 * @author : luyu
 * @date :2021/3/18 19:17
 */
@Slf4j
public class RedisRepository implements Repository {
    // Redis access helper used by every method of this repository for its reads and writes.
    private RedisClient redisClient = new RedisClient();

    /**
     * Removes every piece of session state stored for a client: its subscriptions,
     * its will status and its unacknowledged outbound packets, in that order.
     *
     * @param clientId identifier of the client whose session state is removed
     */
    @Override
    public void deleteSessionStatus(String clientId) {
        // Subscriptions first: the per-client hash and the client's entries under each topic filter.
        this.deleteSubscriber(clientId);
        // Then the stored will information.
        this.deleteWillStatus(clientId);
        // Finally the outbound PUBLISH / PUBREL packets that were never acknowledged.
        this.deleteSendPackets(clientId);
    }

    /**
     * Key prefix for a client's will status.
     * Stored as a plain key/value string.
     */
    public static final String WILL_STATUS_KEY_PREFIX = "mqtt:will_status:";

    /**
     * Stores the will status of a client, serialized as JSON, under
     * {@code WILL_STATUS_KEY_PREFIX + clientId}.
     *
     * @param clientId   client identifier; must not be {@code null}
     * @param willStatus will status to serialize and store
     * @throws AppException with {@code SAVE_WILL_STATUS_CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public void saveWillStatus(String clientId, WillStatus willStatus) {
        if (null == clientId) {
            throw new AppException(ErrorCode.SAVE_WILL_STATUS_CLIENT_ID_NULL);
        }
        final String key = WILL_STATUS_KEY_PREFIX + clientId;
        final String json = JSON.toJSONString(willStatus);
        redisClient.set(key, json);
    }

    /**
     * Deletes the will status stored under {@code WILL_STATUS_KEY_PREFIX + clientId}.
     *
     * @param clientId client identifier; must not be {@code null}
     * @throws AppException with {@code DELETE_WILL_STATUS_CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public void deleteWillStatus(String clientId) {
        if (null == clientId) {
            throw new AppException(ErrorCode.DELETE_WILL_STATUS_CLIENT_ID_NULL);
        }
        final String key = WILL_STATUS_KEY_PREFIX + clientId;
        redisClient.delete(key);
    }

    /**
     * Loads the will status stored for a client.
     *
     * @param clientId client identifier; must not be {@code null}
     * @return the deserialized will status, or {@code null} when nothing is stored for the client
     * @throws AppException with {@code ErrorCode.CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public WillStatus getWillStatus(String clientId) {
        if (clientId == null) {
            // This is a read, not a delete: report the generic "client id is null" code
            // instead of the delete-specific one that had been copied here.
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        String value = redisClient.get(WILL_STATUS_KEY_PREFIX + clientId);
        if (value == null) {
            return null;
        }
        return JSON.parseObject(value, WillStatus.class);
    }

    /**
     * Key prefix for the subscriptions of one client.
     * Stored as a HASH:
     * hash key (suffix after this prefix): clientId
     * field: topicFilter
     * value: qos
     */
    public static final String SUBSCRIBE_INFO_KEY_PREFIX = "mqtt:subscribe_info:";
    /**
     * Key prefix for the clients subscribed to one topic filter.
     * Stored as a HASH:
     * hash key (suffix after this prefix): topicFilter
     * field: clientId
     * value: qos
     */
    public static final String SUBSCRIBER_KEY_PREFIX = "mqtt:subscriber:";

    /**
     * Lists every stored subscription by scanning all {@code SUBSCRIBER_KEY_PREFIX*} hashes.
     *
     * @return one {@link Subscription} per (topic filter, client) pair; empty when nothing is stored,
     *         never {@code null}
     */
    @Override
    public List<Subscription> getSubscriptions() {
        List<Subscription> subscriptions = new ArrayList<>();
        Set<String> subscriberKeys = redisClient.scan(SUBSCRIBER_KEY_PREFIX + "*");
        if (subscriberKeys == null) {
            return subscriptions;
        }
        for (String subscriberKey : subscriberKeys) {
            Map<String, String> subscribers = redisClient.getHashMap(subscriberKey);
            if (subscribers == null) {
                continue;
            }
            // The scanned value is the full Redis key; the topic filter is only the part after
            // the prefix. Previously the whole key was reported as the topic filter.
            String topicFilter = subscriberKey.startsWith(SUBSCRIBER_KEY_PREFIX)
                    ? subscriberKey.substring(SUBSCRIBER_KEY_PREFIX.length())
                    : subscriberKey;
            for (Map.Entry<String, String> entry : subscribers.entrySet()) {
                Subscription subscription = new Subscription();
                subscription.setClientId(entry.getKey());
                subscription.setTopicFilter(topicFilter);
                subscription.setQos(Integer.parseInt(entry.getValue()));
                subscriptions.add(subscription);
            }
        }
        return subscriptions;
    }

    /**
     * Stores one subscription of a client in both directions: topic filter to QoS under the
     * client's hash, and client id to QoS under the topic filter's hash.
     *
     * @param clientId      client identifier; must not be {@code null}
     * @param subscribeInfo topic filter and QoS to store; nothing is written when {@code null}
     * @throws AppException with {@code ErrorCode.CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public void saveSubscribeInfo(String clientId, SubscribeInfo subscribeInfo) {
        if (clientId == null) {
            // Use the generic code, matching the other subscription methods, rather than the
            // will-status delete code that had been copied here.
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        if (subscribeInfo == null) {
            return;
        }
        String topicFilter = subscribeInfo.getTopicFilter();
        String qos = Integer.toString(subscribeInfo.getQos());
        // Per-client view of the subscription.
        redisClient.putHash(SUBSCRIBE_INFO_KEY_PREFIX + clientId, topicFilter, qos);
        // Per-topic-filter view of the subscription.
        redisClient.putHash(SUBSCRIBER_KEY_PREFIX + topicFilter, clientId, qos);
    }

    /**
     * Finds the clients subscribed to any topic filter that the given topic can match.
     * Candidate filters are enumerated from the topic and each one is looked up under
     * {@code SUBSCRIBER_KEY_PREFIX}; a client found under several filters is returned once,
     * with the highest QoS seen.
     *
     * @param topic the published topic name; must not be {@code null}
     * @return the matching subscribers (possibly empty)
     * @throws AppException with {@code ErrorCode.TOPIC_NULL} when {@code topic} is {@code null}
     */
    @Override
    public List<Subscriber> getSubscriber(String topic) {
        if (topic == null) {
            throw new AppException(ErrorCode.TOPIC_NULL);
        }
        Map<String, Subscriber> subscribers = new HashMap<>();
        /**
         * How the candidate keys are enumerated:
         * A topic filter without "#" is encoded as a binary number with one bit per level:
         * 0 keeps the topic's own level, 1 replaces it with the wildcard '+'.
         * For a topic with n levels, decoding every number from 0 to 2^n-1 yields the
         * 2^n filters without "#".
         * Appending "/#" to each of those yields the "#" filters with n+1 levels.
         * Dropping the last level of the topic and appending "/#" to the filters of the
         * remaining levels yields the "#" filters with n levels.
         * Truncation and enumeration repeat until only the first level is left.
         * Finally there is the bare "#" filter.
         *
         * Example: topic = "a/b/c"
         *
         * Filters without "#":
         * a/b/c
         * +/b/c
         * a/+/c
         * a/b/+
         * +/+/c
         * +/b/+
         * a/+/+
         * +/+/+
         * Filters with "#":
         * a/b/c/#   a/b/#   a/#   #
         * +/b/c/#   +/b/#   +/#
         * a/+/c/#   a/+/#
         * a/b/+/#   +/+/#
         * +/+/c/#
         * +/b/+/#
         * a/+/+/#
         * +/+/+/#
         *
         * NOTE(review): each pass issues 2^n (or 2 * 2^n on the first pass) hash lookups, so the
         * number of Redis calls grows exponentially with the number of topic levels.
         * NOTE(review): the loop stops as soon as the truncated topic is the empty string, so for
         * a topic that starts with "/" the pass over the single empty first level (which would
         * produce "/#" and "+/#") never runs - confirm whether that is intended.
         */
        String subTopic = topic;
        int topicN = split(topic).length;
        while (subTopic.length() > 0) {
            String[] topicSegs = split(subTopic);
            int n = topicSegs.length;
            // Enumerate the filters without "#" for the current (possibly truncated) topic
            for (int i = 0; i < (1 << n); i++) {
                StringBuilder topicFilter = new StringBuilder();
                for (int j = 0; j < n; j++) {
                    // Bit (n - j - 1) of i decides whether level j stays literal or becomes '+'
                    topicFilter.append((i >> (n - j - 1) & 0x01) == 0 ? topicSegs[j] : "+");
                    if (j < n - 1) {
                        topicFilter.append("/");
                    }
                }
                // Filters without "#" only apply when they cover every level of the full topic
                if (n == topicN) {
                    // Fetch the clients subscribed to the filter without "#"
                    Map<String, String> subscribersPart = redisClient.getHashMap(SUBSCRIBER_KEY_PREFIX + topicFilter.toString());
                    // Merge the fetched clients into the subscriber map
                    buildSubscribers(subscribers, subscribersPart);
                }
                // Fetch the clients subscribed to the same filter with "/#" appended
                Map<String, String> subscribersPart = redisClient.getHashMap(SUBSCRIBER_KEY_PREFIX + topicFilter.toString() + "/#");
                // Merge the fetched clients into the subscriber map
                buildSubscribers(subscribers, subscribersPart);
            }

            // Drop the last level; an empty string ends the loop
            int lastSplitIndex = subTopic.lastIndexOf("/");
            if (lastSplitIndex >= 0) {
                subTopic = subTopic.substring(0, lastSplitIndex);
            } else {
                subTopic = "";
            }
        }
        // Fetch the clients subscribed to the bare "#" filter
        Map<String, String> subscribersPart = redisClient.getHashMap(SUBSCRIBER_KEY_PREFIX + "#");
        // Merge the fetched clients into the subscriber map
        buildSubscribers(subscribers, subscribersPart);
        return new ArrayList<>(subscribers.values());
    }

    /**
     * Merges one hash of (client id, QoS) pairs into the accumulated subscriber map.
     * A client is added when it is not present yet, and replaced when the new QoS is
     * strictly higher than the one already recorded.
     *
     * @param subscribers     accumulated subscribers keyed by client id; updated in place
     * @param subscribersPart client id to QoS string, as read from Redis; may be {@code null}
     */
    private void buildSubscribers(Map<String, Subscriber> subscribers, Map<String, String> subscribersPart) {
        if (subscribersPart == null) {
            return;
        }
        for (Map.Entry<String, String> entry : subscribersPart.entrySet()) {
            String id = entry.getKey();
            int candidateQos = Integer.parseInt(entry.getValue());
            Subscriber known = subscribers.get(id);
            // Keep the existing entry unless the candidate has a strictly higher QoS.
            if (known != null && known.getQos() >= candidateQos) {
                continue;
            }
            Subscriber replacement = new Subscriber();
            replacement.setClientId(id);
            replacement.setQos(candidateQos);
            subscribers.put(id, replacement);
        }
    }

    /**
     * Removes a single subscription of a client from both the per-client hash and the
     * per-topic-filter hash.
     *
     * @param clientId    client identifier; must not be {@code null}
     * @param topicFilter topic filter to unsubscribe; nothing is deleted when {@code null}
     * @throws AppException with {@code CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public void deleteSubscribeInfo(String clientId, String topicFilter) {
        if (null == clientId) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        if (null == topicFilter) {
            return;
        }
        final String clientKey = SUBSCRIBE_INFO_KEY_PREFIX + clientId;
        redisClient.deleteHashValue(clientKey, topicFilter);
        final String filterKey = SUBSCRIBER_KEY_PREFIX + topicFilter;
        redisClient.deleteHashValue(filterKey, clientId);
    }

    /**
     * Removes every subscription of a client: its entry under each subscribed topic filter,
     * then the client's own subscription hash.
     *
     * @param clientId client identifier; must not be {@code null}
     * @throws AppException with {@code CLIENT_ID_NULL} when {@code clientId} is {@code null}
     */
    @Override
    public void deleteSubscriber(String clientId) {
        if (clientId == null) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        // Look up the topic filters this client is subscribed to.
        Map<String, String> subscribeInfoMap = redisClient.getHashMap(SUBSCRIBE_INFO_KEY_PREFIX + clientId);
        // The lookup result is null-checked by every other caller in this class; a client
        // without stored subscriptions must not fail here with a NullPointerException.
        if (subscribeInfoMap != null) {
            // Remove the client from each topic filter's subscriber hash.
            for (String topicFilter : subscribeInfoMap.keySet()) {
                redisClient.deleteHashValue(SUBSCRIBER_KEY_PREFIX + topicFilter, clientId);
            }
        }
        // Remove the client's own subscription hash.
        redisClient.delete(SUBSCRIBE_INFO_KEY_PREFIX + clientId);
    }

    /**
     * Key prefix for inbound messages that have not been acknowledged yet.
     * Stored as a plain key/value string.
     */
    public static final String RECEIVE_PACKET_KEY_PREFIX = "mqtt:receive_packet:";

    /**
     * Stores an inbound packet, serialized as JSON, under
     * {@code RECEIVE_PACKET_KEY_PREFIX + packetId}.
     * The {@code clientId} argument is not used when building the key.
     *
     * @param clientId identifier of the sending client (unused here)
     * @param packet   packet to store; its packet id must not be {@code null}
     * @throws AppException with {@code SAVE_PACKET_ID_NULL} when the packet has no packet id
     */
    @Override
    public void saveReceivePacket(String clientId, Packet packet) {
        if (null == packet.getPacketId()) {
            throw new AppException(ErrorCode.SAVE_PACKET_ID_NULL);
        }
        final String key = RECEIVE_PACKET_KEY_PREFIX + packet.getPacketId().toString();
        final String json = JSON.toJSONString(packet);
        redisClient.set(key, json);
    }

    /**
     * Loads the inbound packet stored under {@code RECEIVE_PACKET_KEY_PREFIX + packetId}.
     *
     * @param packetId packet identifier
     * @return the deserialized packet, or {@code null} when no value is stored
     */
    @Override
    public Packet getReceivePacket(int packetId) {
        final String json = redisClient.get(RECEIVE_PACKET_KEY_PREFIX + packetId);
        if (json == null) {
            return null;
        }
        return JSON.parseObject(json, Packet.class);
    }

    /**
     * Deletes the inbound packet stored under {@code RECEIVE_PACKET_KEY_PREFIX + packetId}.
     *
     * @param packetId packet identifier
     */
    @Override
    public void deleteReceivePacket(int packetId) {
        final String key = RECEIVE_PACKET_KEY_PREFIX + packetId;
        redisClient.delete(key);
    }

    /**
     * Key prefix for retained messages; the topic name follows the prefix.
     * Stored as a LIST.
     */
    public static final String RETAIN_PACKET_KEY_PREFIX = "mqtt:retain_packet:";

    /**
     * Pushes a retained packet, serialized as JSON, onto the list stored under
     * {@code RETAIN_PACKET_KEY_PREFIX + topicName}.
     * NOTE(review): this only pushes; whether earlier retained packets for the same topic
     * are cleared beforehand is not visible here - confirm against the caller.
     *
     * @param topicName topic the packet is retained for
     * @param packet    packet to store
     */
    @Override
    public void saveRetainPacket(String topicName, Packet packet) {
        final String key = RETAIN_PACKET_KEY_PREFIX + topicName;
        final String json = JSON.toJSONString(packet);
        redisClient.push(key, json);
    }

    /**
     * Returns the retained packets of every stored topic that matches a topic filter.
     * <ul>
     *   <li>No wildcard: the list stored for exactly that topic is read.</li>
     *   <li>Wildcards: keys are scanned using the literal text before the first wildcard as
     *       prefix, and each scanned topic is then checked level by level against the filter,
     *       so a prefix hit such as {@code ab/c} is not returned for {@code a/#}.</li>
     * </ul>
     * A {@code '#'} is accepted only as the last character and only when it occupies a whole
     * level ({@code "#"} or {@code ".../#"}).
     *
     * @param topicFilter topic filter of the new subscription
     * @return the retained packets of all matching topics; empty when there are none,
     *         never {@code null}
     * @throws AppException with {@code TOPIC_FILTER_FORMAT_ERROR} when the filter is
     *                      {@code null}, empty, or uses {@code '#'} anywhere but as the final level
     */
    @Override
    public List<Packet> getRetainPacket(String topicFilter) {
        if (topicFilter == null || topicFilter.isEmpty()) {
            throw new AppException(ErrorCode.TOPIC_FILTER_FORMAT_ERROR);
        }
        int hashIndex = topicFilter.indexOf('#');
        int plusIndex = topicFilter.indexOf('+');

        // No wildcard at all: exact lookup.
        if (hashIndex < 0 && plusIndex < 0) {
            return parseRetainPackets(redisClient.getList(RETAIN_PACKET_KEY_PREFIX + topicFilter));
        }

        boolean multiLevel = hashIndex >= 0;
        // The levels in front of the trailing "/#"; null for the bare "#" filter.
        String leadingLevels = null;
        if (multiLevel) {
            boolean isLastChar = hashIndex == topicFilter.length() - 1;
            boolean ownsLevel = hashIndex == 0 || topicFilter.charAt(hashIndex - 1) == '/';
            if (!isLastChar || !ownsLevel) {
                throw new AppException(ErrorCode.TOPIC_FILTER_FORMAT_ERROR);
            }
            if (hashIndex > 0) {
                leadingLevels = topicFilter.substring(0, hashIndex - 1);
            }
        }

        // Narrow the scan with the literal text in front of the first wildcard.
        String scanPrefix;
        if (plusIndex >= 0) {
            scanPrefix = topicFilter.substring(0, plusIndex);
        } else if (leadingLevels == null) {
            scanPrefix = "";
        } else if (leadingLevels.isEmpty()) {
            // "/#": every matching topic starts with "/".
            scanPrefix = "/";
        } else {
            // "a/#" also matches the parent topic "a" itself, so the separator is left out.
            scanPrefix = leadingLevels;
        }

        List<Packet> packets = new ArrayList<>();
        Set<String> keys = redisClient.scan(RETAIN_PACKET_KEY_PREFIX + scanPrefix + "*");
        if (keys == null) {
            return packets;
        }
        for (String key : keys) {
            // The topic is whatever follows the key prefix (the prefix itself contains ':').
            String topic = key.substring(RETAIN_PACKET_KEY_PREFIX.length());
            boolean matched = multiLevel
                    ? matchLeadingLevels(topic, leadingLevels)
                    : match(topic, topicFilter);
            if (matched) {
                // Collect from every matching key instead of returning after the first one.
                packets.addAll(parseRetainPackets(redisClient.getList(key)));
            }
        }
        return packets;
    }

    /**
     * Deserializes the JSON strings of one retained-packet list.
     *
     * @param values JSON strings read from Redis; may be {@code null}
     * @return the parsed packets; empty when {@code values} is {@code null}
     */
    private List<Packet> parseRetainPackets(List<String> values) {
        if (values == null) {
            return new ArrayList<>();
        }
        return values.stream()
                .map(item -> JSON.parseObject(item, Packet.class))
                .collect(Collectors.toList());
    }

    /**
     * Checks a topic against the levels that precede a trailing {@code "/#"}.
     * The topic matches when it has at least as many levels as the filter part and each of
     * those levels is equal to the filter level or the filter level is {@code "+"}.
     *
     * @param topic         topic name taken from a scanned key
     * @param leadingLevels filter text before {@code "/#"}, or {@code null} for the bare {@code "#"}
     * @return {@code true} when the topic is covered by the filter
     */
    private boolean matchLeadingLevels(String topic, String leadingLevels) {
        if (leadingLevels == null) {
            return true;
        }
        // limit -1 keeps empty levels such as the one in front of a leading '/'.
        String[] filterLevels = leadingLevels.split("/", -1);
        String[] topicLevels = topic.split("/", -1);
        if (topicLevels.length < filterLevels.length) {
            return false;
        }
        for (int i = 0; i < filterLevels.length; i++) {
            if (!"+".equals(filterLevels[i]) && !filterLevels[i].equals(topicLevels[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Matches a topic against a topic filter that may contain {@code '+'} but no {@code '#'}.
     * Both are split into levels; they match when the level counts are equal and every
     * filter level is either identical to the topic level or is {@code "+"}.
     *
     * @param topic       topic name
     * @param topicFilter topic filter without {@code '#'}
     * @return {@code true} when the topic matches the filter
     */
    private boolean match(String topic, String topicFilter) {
        final String[] topicLevels = split(topic);
        final String[] filterLevels = split(topicFilter);
        if (topicLevels.length != filterLevels.length) {
            return false;
        }
        int level = 0;
        while (level < filterLevels.length) {
            // A level passes when it is identical, or when the filter level is the wildcard.
            if (!(topicLevels[level].equals(filterLevels[level])
                    || filterLevels[level].equals("+"))) {
                return false;
            }
            level++;
        }
        return true;
    }

    /**
     * Deletes the retained-packet list stored under {@code RETAIN_PACKET_KEY_PREFIX + topicName}.
     *
     * @param topicName topic whose retained packets are removed
     */
    @Override
    public void deleteRetainPackets(String topicName) {
        final String key = RETAIN_PACKET_KEY_PREFIX + topicName;
        redisClient.deleteList(key);
    }

    /**
     * Key prefix for outbound PUBLISH or PUBREL packets that have not been acknowledged yet.
     * Stored as a HASH.
     */
    public static final String SEND_PACKET_KEY_PREFIX = "mqtt:send_packet:";

    /**
     * Stores an outbound packet, serialized as JSON, in the client's send-packet hash,
     * using the packet id as the hash field.
     *
     * @param clientId identifier of the receiving client
     * @param packet   outbound packet awaiting acknowledgement
     */
    @Override
    public void saveSendPackets(String clientId, Packet packet) {
        final String key = SEND_PACKET_KEY_PREFIX + clientId;
        final String field = packet.getPacketId().toString();
        final String json = JSON.toJSONString(packet);
        redisClient.putHash(key, field, json);
    }

    /**
     * Deletes the whole send-packet hash of a client.
     *
     * @param clientId identifier of the client
     */
    @Override
    public void deleteSendPackets(String clientId) {
        final String key = SEND_PACKET_KEY_PREFIX + clientId;
        redisClient.delete(key);
    }

    /**
     * Loads the unacknowledged outbound packets of a client, ordered by send time,
     * most recent first.
     *
     * @param clientId identifier of the client
     * @return the deserialized packets, or {@code null} when the hash values could not be read
     */
    @Override
    public List<Packet> getSendPackets(String clientId) {
        final List<String> serialized = redisClient.getHashValues(SEND_PACKET_KEY_PREFIX + clientId);
        if (serialized == null) {
            return null;
        }
        final List<Packet> packets = new ArrayList<>();
        for (String json : serialized) {
            packets.add(JSON.parseObject(json, Packet.class));
        }
        // Descending by the send-time stamp carried in each packet.
        packets.sort(Comparator.comparing(Packet::getSendTime).reversed());
        return packets;
    }

    /**
     * Removes one packet from a client's send-packet hash.
     *
     * @param clientId identifier of the client
     * @param packetId id of the packet to remove; used as the hash field
     */
    @Override
    public void deleteSendPacket(String clientId, int packetId) {
        final String key = SEND_PACKET_KEY_PREFIX + clientId;
        final String field = Integer.toString(packetId);
        redisClient.deleteHashValue(key, field);
    }

    /**
     * Counts the unacknowledged outbound packets per client by scanning every
     * {@code SEND_PACKET_KEY_PREFIX*} hash.
     *
     * @return one entry per scanned key, holding the client id and the size of its hash
     *         (0 when the hash could not be read); empty when the scan yields nothing
     */
    @Override
    public List<ClientNoAckSendPacketCnt> countSendPacket() {
        final List<ClientNoAckSendPacketCnt> counts = new ArrayList<>();
        final Set<String> keys = redisClient.scan(SEND_PACKET_KEY_PREFIX + "*");
        if (keys == null) {
            return counts;
        }
        for (String key : keys) {
            // The client id is the part of the key that follows the prefix.
            final String clientId = key.substring(SEND_PACKET_KEY_PREFIX.length());
            final Map<String, String> pending = redisClient.getHashMap(key);
            final ClientNoAckSendPacketCnt count = new ClientNoAckSendPacketCnt();
            count.setClientId(clientId);
            count.setCnt(pending == null ? 0 : pending.size());
            counts.add(count);
        }
        return counts;
    }

    /**
     * Key prefix for the per-client circuit-breaker counter; the remote address follows the prefix.
     */
    private static final String FAILOVER_COUNTER_KEY_PREFIX = "mqtt:failover_counter:";

    /**
     * Reads the failover counter stored for a remote address.
     *
     * @param remoteAddress remote address of the client; may be {@code null}
     * @return {@code null} when the address is {@code null}, 0 when no counter is stored,
     *         otherwise the stored counter value
     */
    @Override
    public Integer getClientConnectErrorTimes(String remoteAddress) {
        // Without an address there is nothing to look up.
        if (null == remoteAddress) {
            return null;
        }
        final String stored = redisClient.get(FAILOVER_COUNTER_KEY_PREFIX + remoteAddress);
        return StringUtil.isEmpty(stored) ? 0 : Integer.parseInt(stored);
    }

    /**
     * Increments the failover counter of a remote address. When no counter is stored yet it
     * is created with value "1" and the given timeout; otherwise the existing counter is
     * incremented and the timeout argument is not passed on.
     * NOTE(review): the read and the following write are separate Redis calls, so two
     * concurrent callers may both take the "create" path - confirm whether that matters.
     *
     * @param remoteAddress  remote address of the client; nothing happens when {@code null}
     * @param timeoutSeconds timeout applied when the counter is first created
     */
    @Override
    public void increaseFailoverCounter(String remoteAddress, int timeoutSeconds) {
        // Without an address there is nothing to count.
        if (null == remoteAddress) {
            return;
        }
        final String current = redisClient.get(FAILOVER_COUNTER_KEY_PREFIX + remoteAddress);
        if (!StringUtil.isEmpty(current)) {
            redisClient.inrc(FAILOVER_COUNTER_KEY_PREFIX + remoteAddress);
            return;
        }
        redisClient.set(FAILOVER_COUNTER_KEY_PREFIX + remoteAddress, "1", timeoutSeconds);
    }
}
